da13a1
@@ -22,11 +22,13 @@
 import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping.STRING_GROUP;
 import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping.VOID_GROUP;
 
+import java.text.SimpleDateFormat;
 import java.util.Calendar;
 import java.util.Date;
 
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -44,32 +46,61 @@
  *
  */
 @Description(name = "add_months",
-    value = "_FUNC_(start_date, num_months) - Returns the date that is num_months after start_date.",
-    extended = "start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or"
-        + " 'yyyy-MM-dd'. num_months is a number. The time part of start_date is "
-        + "ignored.\n"
-        + "Example:\n " + " > SELECT _FUNC_('2009-08-31', 1) FROM src LIMIT 1;\n" + " '2009-09-30'")
+    value = "_FUNC_(start_date, num_months, output_date_format) - "
+        + "Returns the date that is num_months after start_date.",
+    extended = "start_date is a string or timestamp indicating a valid date. "
+        + "num_months is a number. output_date_format is an optional String which specifies the format for output.\n"
+        + "The default output format is 'YYYY-MM-dd'.\n"
+        + "Example:\n  > SELECT _FUNC_('2009-08-31', 1) FROM src LIMIT 1;\n" + " '2009-09-30'."
+        + "\n  > SELECT _FUNC_('2017-12-31 14:15:16', 2, 'YYYY-MM-dd HH:mm:ss') LIMIT 1;\n"
+        + "'2018-02-28 14:15:16'.\n")
 @NDV(maxNdv = 250) // 250 seems to be reasonable upper limit for this
 public class GenericUDFAddMonths extends GenericUDF {
-  private transient Converter[] converters = new Converter[2];
-  private transient PrimitiveCategory[] inputTypes = new PrimitiveCategory[2];
-  private final Calendar calendar = Calendar.getInstance();
+  private transient Converter[] tsConverters = new Converter[3];
+  private transient PrimitiveCategory[] tsInputTypes = new PrimitiveCategory[3];
+  private transient Converter[] dtConverters = new Converter[3];
+  private transient PrimitiveCategory[] dtInputTypes = new PrimitiveCategory[3];
   private final Text output = new Text();
+  private transient SimpleDateFormat formatter = null;
+  private final Calendar calendar = Calendar.getInstance();
   private transient Integer numMonthsConst;
   private transient boolean isNumMonthsConst;
 
   @Override
   public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
-    checkArgsSize(arguments, 2, 2);
+    checkArgsSize(arguments, 2, 3);
 
     checkArgPrimitive(arguments, 0);
     checkArgPrimitive(arguments, 1);
 
-    checkArgGroups(arguments, 0, inputTypes, STRING_GROUP, DATE_GROUP, VOID_GROUP);
-    checkArgGroups(arguments, 1, inputTypes, NUMERIC_GROUP, VOID_GROUP);
+    if (arguments.length == 3) {
+      if (arguments[2] instanceof ConstantObjectInspector) {
+        checkArgPrimitive(arguments, 2);
+        checkArgGroups(arguments, 2, tsInputTypes, STRING_GROUP);
+        String fmtStr = getConstantStringValue(arguments, 2);
+        if (fmtStr != null) {
+          formatter = new SimpleDateFormat(fmtStr);
+        }
+      } else {
+        throw new UDFArgumentTypeException(2, getFuncName() + " only takes constant as "
+            + getArgOrder(2) + " argument");
+      }
+    }
+    if (formatter == null) {
+      // No output format was supplied (or the constant was null): fall back to the default date format.
+      formatter = DateUtils.getDateFormat();
+    }
+
+    // Argument 0 may be a short date or a full timestamp, so both converters are set up for it;
+    // the time-of-day part of a timestamp must not be dropped.
+    checkArgGroups(arguments, 0, tsInputTypes, STRING_GROUP, DATE_GROUP, VOID_GROUP);
+    checkArgGroups(arguments, 0, dtInputTypes, STRING_GROUP, DATE_GROUP, VOID_GROUP);
 
-    obtainDateConverter(arguments, 0, inputTypes, converters);
-    obtainIntConverter(arguments, 1, inputTypes, converters);
+    obtainTimestampConverter(arguments, 0, tsInputTypes, tsConverters);
+    obtainDateConverter(arguments, 0, dtInputTypes, dtConverters);
+
+    checkArgGroups(arguments, 1, tsInputTypes, NUMERIC_GROUP, VOID_GROUP);
+    obtainIntConverter(arguments, 1, tsInputTypes, tsConverters);
 
     if (arguments[1] instanceof ConstantObjectInspector) {
       numMonthsConst = getConstantIntValue(arguments, 1);
@@ -86,7 +117,7 @@
public Object evaluate(DeferredObject[] arguments) throws HiveException {
     if (isNumMonthsConst) {
       numMonthV = numMonthsConst;
     } else {
-      numMonthV = getIntValue(arguments, 1, converters);
+      numMonthV = getIntValue(arguments, 1, tsConverters);
     }
 
     if (numMonthV == null) {
@@ -94,14 +125,22 @@
public Object evaluate(DeferredObject[] arguments) throws HiveException {
     }
 
     int numMonthInt = numMonthV.intValue();
-    Date date = getDateValue(arguments, 0, inputTypes, converters);
+
+    // Try the timestamp conversion first so that the time-of-day part is kept;
+    // fall back to the date conversion when the timestamp conversion yields null.
+    Date date = getTimestampValue(arguments, 0, tsConverters);
     if (date == null) {
-      return null;
+      date = getDateValue(arguments, 0, dtInputTypes, dtConverters);
+      if (date == null) {
+        return null;
+      }
     }
 
     addMonth(date, numMonthInt);
     Date newDate = calendar.getTime();
-    output.set(DateUtils.getDateFormat().format(newDate));
+    String res = formatter.format(newDate);
+
+    output.set(res);
     return output;
   }
 
